{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5755bc04",
   "metadata": {},
   "source": [
    "# Decision Tree Regression Using Concrete ML\n",
    "\n",
    "In this tutorial, we show how to create, train and evaluate a decision tree regression model using the Concrete ML library.\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2c256087-c16a-4249-9c90-3f4863938385",
   "metadata": {},
   "source": [
    "### Introducing Concrete ML\n",
    "\n",
    "> Concrete ML is an open-source, privacy-preserving, machine learning inference framework based on fully homomorphic encryption (FHE).\n",
    "> It enables data scientists without any prior knowledge of cryptography to automatically turn machine learning models into their FHE equivalent, using familiar APIs from Scikit-learn and PyTorch.\n",
    "> <cite>&mdash; [Zama documentation](../README.md)</cite>\n",
    "\n",
    "This tutorial does not require a deep understanding of the technology behind Concrete ML.\n",
    "Nonetheless, newcomers might be interested in reading introductory sections of the official documentation such as:\n",
    "\n",
    "- [What is Concrete ML](../README.md)\n",
    "- [Key Concepts](../getting-started/concepts.md)\n",
    "\n",
    "In the tutorial, we will be using the following terminology:\n",
    "\n",
    "- plaintext: unprotected data, readable by anyone who has access to it.\n",
    "- ciphertext: encrypted data, which can only be deciphered by someone who knows the secret key.\n",
    "\n",
    "Conventional models work with plaintext, whereas Concrete ML can work directly with ciphertext.\n",
    "Privacy is preserved because the model does not know the secret key and thus cannot decipher the data.\n",
    "Yet it outputs an encrypted estimate that the owner of the secret key can decipher."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "777377ce",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "We will be using the California housing prices dataset available in the scikit-learn package.\n",
    "The data was derived from the 1990 U.S. census, using one row per census block group, where\n",
    "a block group typically has a population of 600 to 3,000 people.\n",
    "\n",
    "- Target Variable : Median house value in $100'000\n",
    "- Number of Observations :  20'640\n",
    "- Number of Attributes : 8 numeric, predictive attributes\n",
    "- Attribute Information :\n",
    "      - MedInc        median income in block group\n",
    "      - HouseAge      median house age in block group\n",
    "      - AveRooms      average number of rooms per household\n",
    "      - AveBedrms     average number of bedrooms per household\n",
    "      - Population    block group population\n",
    "      - AveOccup      average number of household members\n",
    "      - Latitude      block group latitude\n",
    "      - Longitude     block group longitude\n",
    "\n",
    "- Missing Attribute Values : None\n",
    "\n",
    "To speed up computations, we will use a bootstrapped resample of 6'000 observations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0e3ca1b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using ConcreteML version 1.6.0\n",
      "With Python version 3.8.18 (default, Jul 16 2024, 19:04:03) \n",
      "[GCC 9.4.0]\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import time\n",
    "\n",
    "import numpy\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.utils import resample\n",
    "\n",
    "import concrete.ml\n",
    "from concrete.ml.sklearn import DecisionTreeRegressor as ConcreteDecisionTreeRegressor\n",
    "\n",
    "print(f\"Using ConcreteML version {concrete.ml.version.__version__}\")\n",
    "print(f\"With Python version {sys.version}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fab56356",
   "metadata": {},
   "outputs": [],
   "source": [
    "features_all, target_all = fetch_california_housing(return_X_y=True)\n",
    "features, target = resample(features_all, target_all, replace=True, n_samples=6000, random_state=42)\n",
    "\n",
    "# Split data in train-test groups\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    features,\n",
    "    target,\n",
    "    test_size=0.15,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
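  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b41c7e90",
   "metadata": {},
   "source": [
    "A bootstrapped resample draws rows with replacement, so some rows appear several times while others are left out.\n",
    "The sketch below illustrates this with NumPy alone, on row indices rather than on the housing data. The helper name `bootstrap_indices` and the seed are assumptions made for illustration; the tutorial itself relies on `sklearn.utils.resample`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def bootstrap_indices(n_rows, n_samples, seed=0):\n",
    "    '''Draw n_samples row indices uniformly at random, with replacement.'''\n",
    "    rng = np.random.default_rng(seed)\n",
    "    return rng.integers(0, n_rows, size=n_samples)\n",
    "\n",
    "\n",
    "# Same sizes as in this tutorial: 20,640 rows and 6,000 draws\n",
    "idx = bootstrap_indices(n_rows=20640, n_samples=6000)\n",
    "n_unique = np.unique(idx).size\n",
    "print(f'{idx.size} draws, {n_unique} distinct rows')\n",
    "```\n",
    "\n",
    "Because of the repeated rows, the same observation can end up in both the training and the test split, which is worth keeping in mind when reading the test errors."
   ]
  },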
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "09a05cb8",
   "metadata": {},
   "source": [
    "We also provide a quick visualisation of our target variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1d9b9717",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAX1ElEQVR4nO3db2yV9dnA8auF0DqhRURLwGqnbhJ10EmhY5t/tlWJachItoQQI6wxvhkQTWMibIbKnGm3OIMZBJ2ZM3EhsC3TJeIwrhkasxqwjEzdNJsZAcUW2JIWa1ZM2+eFz1OfDoocLFy0/XySk9ib+891zgv7ze+cc7doYGBgIAAAkhRnDwAAjG9iBABIJUYAgFRiBABIJUYAgFRiBABIJUYAgFRiBABINTF7gFPR398fBw8ejClTpkRRUVH2OADAKRgYGIijR4/GzJkzo7h4+PWPUREjBw8ejMrKyuwxAIDTcODAgbjkkkuG/fdRESNTpkyJiI+eTFlZWfI0AMCp6O7ujsrKysHf48MZFTHyf2/NlJWViREAGGU+6SMWPsAKAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAqonZA8CZVrVme8p197XUp1wXYLSxMgIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAECq04qRTZs2RVVVVZSWlkZtbW3s2rXrlI7bunVrFBUVxZIlS07nsgDAGFRwjGzbti0aGxujqakp9uzZE3Pnzo1FixbFoUOHTnrcvn374p577onrr7/+tIcFAMaegmPk4YcfjjvvvDMaGhri6quvjkcffTQ+85nPxBNPPDHsMX19fXHbbbfF+vXr4/LLL/9UAwMAY0tBMXLs2LFob2+Purq6j09QXBx1dXXR1tY27HE/+MEP4uKLL4477rjj9CcFAMakiYXsfOTIkejr64uKiooh2ysqKuLNN9884TEvv/xy/PznP4+9e/ee8nV6e3ujt7d38Ofu7u5CxgQARpGCYqRQR48ejdtvvz0ef/zxmD59+ikf19zcHOvXrz+Dk41fVWu2p1x3X0t9ynUBOPcVFCPTp0+PCRMmRGdn55DtnZ2dMWPGjOP2f/vtt2Pfvn2xePHiwW39/f0fXXjixHjrrbfiiiuuOO64tWvXRmNj4+DP3d3dUVlZWcioAMAoUVCMTJo0KebNmxetra2DX8/t7++P1tbWWLVq1XH7z549O1577bUh2+677744evRoPPLII8MGRklJSZSUlBQyGue4rBUZAM59Bb9N09jYGCtWrIiamppYsGBBbNiwIXp6eqKhoSEiIpYvXx6zZs2K5ubmKC0tjWuvvXbI8VOnTo2IOG47ADA+FRwjS5cujcOHD8e6deuio6MjqqurY8eOHYMfat2/f38UF7uxKwBwaooGBgYGsof4JN3d3VFeXh5dXV1RVlaWPc6o5u2Ss8eHdoHx7lR/f1vCAABSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSTcweAMaqqjXbU667r6U+5boAp8vKCACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnE
CACQyt+mSZL1d0sA4FxjZQQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUpxUjmzZtiqqqqigtLY3a2trYtWvXsPv+9re/jZqampg6dWqcf/75UV1dHU899dRpDwwAjC0Fx8i2bduisbExmpqaYs+ePTF37txYtGhRHDp06IT7T5s2Lb7//e9HW1tb/OUvf4mGhoZoaGiI559//lMPDwCMfkUDAwMDhRxQW1sb8+fPj40bN0ZERH9/f1RWVsbq1atjzZo1p3SO6667Lurr6+OBBx44pf27u7ujvLw8urq6oqysrJBxz1luB8+Zsq+lPnsEgIg49d/fBa2MHDt2LNrb26Ouru7jExQXR11dXbS1tX3i8QMDA9Ha2hpvvfVW3HDDDcPu19vbG93d3UMeAMDYVFCMHDlyJPr6+qKiomLI9oqKiujo6Bj2uK6urpg8eXJMmjQp6uvr46c//WncfPPNw+7f3Nwc5eXlg4/KyspCxgQARpGz8m2aKVOmxN69e2P37t3x4IMPRmNjY+zcuXPY/deuXRtdXV2DjwMHDpyNMQGABBML2Xn69OkxYcKE6OzsHLK9s7MzZsyYMexxxcXFceWVV0ZERHV1dfztb3+L5ubmuOmmm064f0lJSZSUlBQyGgAwShW0MjJp0qSYN29etLa2Dm7r7++P1tbWWLhw4Smfp7+/P3p7ewu5NAAwRhW0MhIR0djYGCtWrIiamppYsGBBbNiwIXp6eqKhoSEiIpYvXx6zZs2K5ubmiPjo8x81NTVxxRVXRG9vbzz33HPx1FNPxebNm0f2mQARkfdNLd/iAU5XwTGydOnSOHz4cKxbty46Ojqiuro6duzYMfih1v3790dx8ccLLj09PfHd73433nnnnTjvvPNi9uzZ8ctf/jKWLl06cs8CABi1Cr7PSAb3GYFzn5UR4L+dkfuMAACMNDECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKSamD0AMDZUrdmect19LfUp1wVGjpURACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACDVacXIpk2boqqqKkpLS6O2tjZ27do17L6PP/54XH/99XHBBRfEBRdcEHV1dSfdHwAYXwqOkW3btkVjY2M0NTXFnj17Yu7cubFo0aI4dOjQCfffuXNnLFu2LP74xz9GW1tbVFZWxi233BLvvvvupx4eABj9igYGBgYKOaC2tjbmz58fGzdujIiI/v7+qKysjNWrV8eaNWs+8fi+vr644IILYuPGjbF8+fJTumZ3d3eUl5dHV1dXlJWVFTLuOatqzfbsEWBM2NdSnz0CMIxT/f1d0MrIsWPHor29Perq6j4+QXFx1NXVRVtb2ymd44MPPogPP/wwpk2bNuw+vb290d3dPeQBAIxNBcXIkSNHoq+vLyoqKoZsr6ioiI6OjlM6x7333hszZ84cEjT/rbm5OcrLywcflZWVhYwJAIwiZ/XbNC0tLbF169Z4+umno7S0dNj91q5dG11dXYOPAwcOnMUpAYCzaWIhO0+fPj0mTJgQnZ2dQ7Z3dnbGjBkzTnrsQw89FC0tLfGHP/wh5syZc9J9S0pKoqSkpJDRAIBRqqCVkUmTJsW8efOitbV1cFt/f3+0trbGwoULhz3uxz/+cTzwwAOxY8eOqKmpOf1p
AYAxp6CVkYiIxsbGWLFiRdTU1MSCBQtiw4YN0dPTEw0NDRERsXz58pg1a1Y0NzdHRMSPfvSjWLduXWzZsiWqqqoGP1syefLkmDx58gg+FQBgNCo4RpYuXRqHDx+OdevWRUdHR1RXV8eOHTsGP9S6f//+KC7+eMFl8+bNcezYsfj2t7895DxNTU1x//33f7rpAYBRr+D7jGRwnxFgOO4zAueuM3KfEQCAkSZGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASFXw7eABgDMj6+7c2XcytjICAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAqonZAwB8GlVrtqdcd19Lfcp1YSyyMgIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApBIjAEAqMQIApHIHVoDT4M6vMHKsjAAAqcQIAJBKjAAAqcQIAJDqtGJk06ZNUVVVFaWlpVFbWxu7du0adt833ngjvvWtb0VVVVUUFRXFhg0bTndWAGAMKjhGtm3bFo2NjdHU1BR79uyJuXPnxqJFi+LQoUMn3P+DDz6Iyy+/PFpaWmLGjBmfemAAYGwp+Ku9Dz/8cNx5553R0NAQERGPPvpobN++PZ544olYs2bNcfvPnz8/5s+fHxFxwn8H4NT5SjFjUUErI8eOHYv29vaoq6v7+ATFxVFXVxdtbW0jPhwAMPYVtDJy5MiR6Ovri4qKiiHbKyoq4s033xyxoXp7e6O3t3fw5+7u7hE7NwBwbjknv03T3Nwc5eXlg4/KysrskQCAM6SgGJk+fXpMmDAhOjs7h2zv7Owc0Q+nrl27Nrq6ugYfBw4cGLFzAwDnloJiZNKkSTFv3rxobW0d3Nbf3x+tra2xcOHCERuqpKQkysrKhjwAgLGp4G/TNDY2xooVK6KmpiYWLFgQGzZsiJ6ensFv1yxfvjxmzZoVzc3NEfHRh17/+te/Dv73u+++G3v37o3JkyfHlVdeOYJPBQAYjQqOkaVLl8bhw4dj3bp10dHREdXV1bFjx47BD7Xu378/ios/XnA5ePBgfPGLXxz8+aGHHoqHHnoobrzxxti5c+enfwYAwKhWcIxERKxatSpWrVp1wn/778CoqqqKgYGB07kMADAOnJPfpgEAxg8xAgCkEiMAQKrT+szIWJL1dx4AgI9YGQEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUk3MHgCAc1/Vmu0p193XUp9yXc4uKyMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQKqJ2QMAwHCq1mxPue6+lvqU645XVkYAgFRiBABIJUYAgFRiBABIJUYAgFS+TQMA/yXrWzzjlZURACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUokRACCVGAEAUp1WjGzatCmqqqqitLQ0amtrY9euXSfd/9e//nXMnj07SktL4wtf+EI899xzpzUsADD2FBwj27Zti8bGxmhqaoo9e/bE3LlzY9GiRXHo0KET7v+nP/0pli1bFnfccUf8+c9/jiVLlsSSJUvi9ddf/9TDAwCjX9HA
wMBAIQfU1tbG/PnzY+PGjRER0d/fH5WVlbF69epYs2bNcfsvXbo0enp64tlnnx3c9qUvfSmqq6vj0UcfPaVrdnd3R3l5eXR1dUVZWVkh436iqjXbR/R8ADDa7GupPyPnPdXf3xMLOemxY8eivb091q5dO7ituLg46urqoq2t7YTHtLW1RWNj45BtixYtimeeeWbY6/T29kZvb+/gz11dXRHx0ZMaaf29H4z4OQFgNDkTv1///3k/ad2joBg5cuRI9PX1RUVFxZDtFRUV8eabb57wmI6OjhPu39HRMex1mpubY/369cdtr6ysLGRcAOAUlG84s+c/evRolJeXD/vvBcXI2bJ27dohqyn9/f3x73//Oy688MIoKipKnGz86O7ujsrKyjhw4MCIvzXGyXnt83jt83jt85zJ135gYCCOHj0aM2fOPOl+BcXI9OnTY8KECdHZ2Tlke2dnZ8yYMeOEx8yYMaOg/SMiSkpKoqSkZMi2qVOnFjIqI6SsrMz/GJJ47fN47fN47fOcqdf+ZCsi/6egb9NMmjQp5s2bF62trYPb+vv7o7W1NRYuXHjCYxYuXDhk/4iIF154Ydj9AYDxpeC3aRobG2PFihVRU1MTCxYsiA0bNkRPT080NDRERMTy5ctj1qxZ0dzcHBERd911V9x4443xk5/8JOrr62Pr1q3x6quvxs9+9rORfSYAwKhUcIwsXbo0Dh8+HOvWrYuOjo6orq6OHTt2DH5Idf/+/VFc/PGCy5e//OXYsmVL3HffffG9730vPve5z8UzzzwT11577cg9C0ZcSUlJNDU1Hfd2GWee1z6P1z6P1z7PufDaF3yfEQCAkeRv0wAAqcQIAJBKjAAAqcQIAJBKjHCcl156KRYvXhwzZ86MoqKik/4dIUZOc3NzzJ8/P6ZMmRIXX3xxLFmyJN56663sscaFzZs3x5w5cwZv+rRw4cL4/e9/nz3WuNTS0hJFRUVx9913Z48y5t1///1RVFQ05DF79uyUWcQIx+np6Ym5c+fGpk2bskcZV1588cVYuXJlvPLKK/HCCy/Ehx9+GLfcckv09PRkjzbmXXLJJdHS0hLt7e3x6quvxte//vX45je/GW+88Ub2aOPK7t2747HHHos5c+ZkjzJuXHPNNfHee+8NPl5++eWUOc7Jv01DrltvvTVuvfXW7DHGnR07dgz5+cknn4yLL7442tvb44YbbkiaanxYvHjxkJ8ffPDB2Lx5c7zyyitxzTXXJE01vrz//vtx2223xeOPPx4//OEPs8cZNyZOnHjSP89ytlgZgXNUV1dXRERMmzYteZLxpa+vL7Zu3Ro9PT3+bMVZtHLlyqivr4+6urrsUcaVv//97zFz5sy4/PLL47bbbov9+/enzGFlBM5B/f39cffdd8dXvvIVdys+S1577bVYuHBh/Oc//4nJkyfH008/HVdffXX2WOPC1q1bY8+ePbF79+7sUcaV2traePLJJ+Oqq66K9957L9avXx/XX399vP766zFlypSzOosYgXPQypUr4/XXX097/3Y8uuqqq2Lv3r3R1dUVv/nNb2LFihXx4osvCpIz7MCBA3HXXXfFCy+8EKWlpdnjjCv//+34OXPmRG1tbVx22WXxq1/9Ku64446zOosYgXPMqlWr4tlnn42XXnopLrnkkuxxxo1JkybFlVdeGRER8+bNi927d8cjjzwSjz32WPJkY1t7e3scOnQorrvuusFtfX198dJLL8XGjRujt7c3JkyYkDjh+DF16tT4/Oc/H//4xz/O+rXFCJwjBgYGYvXq1fH000/Hzp0747Of/Wz2SONaf39/9Pb2Zo8x5n3jG9+I1157bci2hoaGmD17dtx7771C5Cx6//334+23347bb7/9rF9bjHCc999/f0gZ//Of/4y9e/fGtGnT4tJLL02cbGxbuXJlbNmyJX73u9/FlClToqOjIyIiysvL47zzzkuebmxbu3Zt3HrrrXHppZfG0aNHY8uWLbFz5854/vnn
s0cb86ZMmXLc56LOP//8uPDCC31e6gy75557YvHixXHZZZfFwYMHo6mpKSZMmBDLli0767OIEY7z6quvxte+9rXBnxsbGyMiYsWKFfHkk08mTTX2bd68OSIibrrppiHbf/GLX8R3vvOdsz/QOHLo0KFYvnx5vPfee1FeXh5z5syJ559/Pm6++ebs0eCMeeedd2LZsmXxr3/9Ky666KL46le/Gq+88kpcdNFFZ32WooGBgYGzflUAgP/lPiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCk+h8Gio//WptaJgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.hist(target, bins=15, density=True)\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f7141309",
   "metadata": {},
   "source": [
    "## Measuring Accuracy\n",
    "\n",
    "We choose the mean absolute error as a measure of accuracy for our models, which has the advantage of preserving the units of measurement, here dollars.\n",
    "\n",
    "We will compare our decision trees to a simpler model (a canary model) to benchmark accuracy.\n",
    "We could use the median of the target, but that would perhaps be too basic.\n",
    "Instead, we expect income to be somewhat proportional to house value, and therefore we use a univariate linear regression on median income (`MedInc`) as our canary."
   ]
  },
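  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c52d8fa1",
   "metadata": {},
   "source": [
    "The mean absolute error is the average absolute difference between the true and the predicted values.\n",
    "As a minimal sketch, the block below computes it by hand on made-up numbers and converts the result to dollars, since the targets are expressed in units of 100,000 dollars. The function name `mae` and the sample values are illustrative assumptions; the tutorial itself uses `sklearn.metrics.mean_absolute_error`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def mae(y_true, y_pred):\n",
    "    '''Mean absolute error: average of |y_true - y_pred|.'''\n",
    "    y_true = np.asarray(y_true, dtype=float)\n",
    "    y_pred = np.asarray(y_pred, dtype=float)\n",
    "    return float(np.mean(np.abs(y_true - y_pred)))\n",
    "\n",
    "\n",
    "# Hypothetical targets and predictions, in units of 100,000 dollars\n",
    "y_true_demo = [1.0, 2.0, 4.0]\n",
    "y_pred_demo = [1.5, 2.0, 3.0]\n",
    "\n",
    "error = mae(y_true_demo, y_pred_demo)  # (0.5 + 0.0 + 1.0) / 3 = 0.5\n",
    "print(f'{error * 10**5:.2f}$')\n",
    "```\n",
    "\n",
    "An error of 0.5 in target units therefore reads as an average miss of 50,000 dollars."
   ]
  },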
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a135fe2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean Absolute Overall Error : 89922.83$\n"
     ]
    }
   ],
   "source": [
    "# Utility functions\n",
    "\n",
    "\n",
    "def print_as_dollars(x):\n",
    "    \"\"\"Format a value given in units of $100,000 as a dollar amount.\"\"\"\n",
    "    return f\"{x * 10**5:.2f}$\"\n",
    "\n",
    "\n",
    "def print_compare_to_baseline(x, baseline_error):\n",
    "    \"\"\"Format the relative change with respect to the baseline, in percent.\"\"\"\n",
    "    return f\"{(x - baseline_error) / baseline_error * 100 :.2f}% of baseline\"\n",
    "\n",
    "\n",
    "mean_error = mean_absolute_error(y_test, numpy.repeat([numpy.median(y_test)], y_test.shape))\n",
    "print(f\"Mean Absolute Overall Error : {print_as_dollars(mean_error)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e1f2990f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline Mean Error : 62719.65$\n"
     ]
    }
   ],
   "source": [
    "canary = LinearRegression()\n",
    "canary.fit(x_train[:, :1], y_train)\n",
    "baseline_error = mean_absolute_error(canary.predict(x_test[:, :1]), y_test)\n",
    "print(f\"Baseline Mean Error : {print_as_dollars(baseline_error)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2e5babd9",
   "metadata": {},
   "source": [
    "## Training A Decision Tree\n",
    "\n",
    "`ConcreteDecisionTreeRegressor` is the Concrete ML equivalent of scikit-learn's `DecisionTreeRegressor`.\n",
    "It supports the same parameters and a similar interface, with the extra capability of predicting directly on ciphertext without the need to decipher it, thus preserving privacy.\n",
    "\n",
    "Currently, Concrete ML models must be trained on plaintext. To see how it works, we train a `ConcreteDecisionTreeRegressor` with mostly default parameters and estimate its accuracy on test data. Note that predictions here are done on plaintext too, but soon we will predict on ciphertext."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8069097d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training on 5100 samples in 21.5514 seconds\n"
     ]
    }
   ],
   "source": [
    "default_model = ConcreteDecisionTreeRegressor(criterion=\"absolute_error\", n_bits=6, random_state=42)\n",
    "\n",
    "begin = time.time()\n",
    "default_model.fit(x_train, y_train)\n",
    "print(f\"Training on {x_train.shape[0]} samples in {(time.time() - begin):.4f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e286d33c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Default Model Mean Error: 43788.10$,-30.18% of baseline\n"
     ]
    }
   ],
   "source": [
    "default_error = mean_absolute_error(default_model.predict(x_test), y_test)\n",
    "print(\n",
    "    f\"Default Model Mean Error: {print_as_dollars(default_error)},\"\n",
    "    f\"{print_compare_to_baseline(default_error, baseline_error)}\"\n",
    ")"
   ]
  },
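  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d63e90b2",
   "metadata": {},
   "source": [
    "The percentage printed above is the relative change of the error with respect to the baseline, so a negative value means a lower error than the canary model.\n",
    "The arithmetic can be reproduced from the two printed errors, as in this small sketch (the function name `relative_change` is an illustrative assumption):\n",
    "\n",
    "```python\n",
    "def relative_change(value, baseline):\n",
    "    '''Relative change of value with respect to baseline, in percent.'''\n",
    "    return (value - baseline) / baseline * 100\n",
    "\n",
    "\n",
    "# Errors in dollars, copied from the outputs above\n",
    "default_error_usd = 43788.10\n",
    "baseline_error_usd = 62719.65\n",
    "\n",
    "change = relative_change(default_error_usd, baseline_error_usd)\n",
    "print(f'{change:.2f}% of baseline')\n",
    "```"
   ]
  },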
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6aa2dad5",
   "metadata": {},
   "source": [
    "## Optimising Hyper Parameters\n",
    "Working on plaintext is considerably faster than on ciphertext.\n",
    "We take this opportunity to search for optimised hyper-parameters for our dataset.\n",
    "Here we use a grid search for simplicity, but any other search method could be used instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3b076c44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyper parameters: {'criterion': 'absolute_error', 'max_depth': 10, 'max_features': 5, 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_bits': 7, 'random_state': 42}\n",
      "Min loss: 43828.16$\n"
     ]
    }
   ],
   "source": [
    "# Find best hyper parameters with cross validation\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "# List of hyper parameters to tune\n",
    "param_grid = {\n",
    "    \"criterion\": [\"absolute_error\"],\n",
    "    \"random_state\": [42],\n",
    "    \"max_depth\": [10],\n",
    "    \"n_bits\": [6, 7],\n",
    "    \"max_features\": [2, 5],\n",
    "    \"min_samples_leaf\": [2, 5],\n",
    "    \"min_samples_split\": [2, 10],\n",
    "}\n",
    "\n",
    "grid_search = GridSearchCV(\n",
    "    ConcreteDecisionTreeRegressor(),\n",
    "    param_grid,\n",
    "    cv=3,\n",
    "    scoring=\"neg_mean_absolute_error\",\n",
    "    error_score=\"raise\",\n",
    "    n_jobs=1,\n",
    ")\n",
    "\n",
    "gs_results = grid_search.fit(x_train, y_train)\n",
    "print(\"Best hyper parameters:\", gs_results.best_params_)\n",
    "print(f\"Min loss: {print_as_dollars(-gs_results.best_score_)}\")"
   ]
  },
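  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e74fa1c3",
   "metadata": {},
   "source": [
    "A grid search fits one model per parameter combination and per cross-validation fold, so its cost grows multiplicatively with every list added to the grid.\n",
    "The sketch below counts the fits implied by the grid above using only the standard library; it does not train anything, and the variable names are illustrative.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "# Same grid as above, copied here to keep the sketch self-contained\n",
    "grid = {\n",
    "    'criterion': ['absolute_error'],\n",
    "    'random_state': [42],\n",
    "    'max_depth': [10],\n",
    "    'n_bits': [6, 7],\n",
    "    'max_features': [2, 5],\n",
    "    'min_samples_leaf': [2, 5],\n",
    "    'min_samples_split': [2, 10],\n",
    "}\n",
    "n_folds = 3\n",
    "\n",
    "# One candidate per element of the Cartesian product of all value lists\n",
    "candidates = [dict(zip(grid, values)) for values in product(*grid.values())]\n",
    "n_fits = len(candidates) * n_folds\n",
    "print(f'{len(candidates)} candidates x {n_folds} folds = {n_fits} fits')\n",
    "```\n",
    "\n",
    "With the default `refit=True`, `GridSearchCV` additionally refits the best candidate once on the whole training set."
   ]
  },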
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "874adb94-5849-49b9-95ec-e5782736dfb1",
   "metadata": {},
   "source": [
    "### Quantization\n",
    "You might have noticed we used a hyper-parameter that has no scikit-learn equivalent: `n_bits`.\n",
    "This is a first specificity of homomorphic encryption: values must be represented as (not too big) integers before encryption.\n",
    "This encoding is called quantization and `n_bits` is related to the maximum size of the quantized values.\n",
    "Put simply, a lower `n_bits` means that quantization is less precise, but FHE computations are faster.\n",
    "For more details, see [quantization](../explanations/quantization.md) in the official documentation.\n",
    "\n",
    "Our model might or might not gain from extra precision and/or efficiency. There is a balance to strike between the two and, as usual, the right balance depends on context.\n",
    "For now, we observe that our model's performance increases with precision, that is, with higher `n_bits`."
   ]
  },
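  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f85ab2d4",
   "metadata": {},
   "source": [
    "To make the precision side of this trade-off concrete, here is a minimal sketch of uniform quantization with NumPy.\n",
    "It is an illustration under simplifying assumptions and not Concrete ML's actual quantizer: the functions `quantize` and `dequantize` are hypothetical, and each value is simply mapped to one of `2**bits` evenly spaced integer levels between the observed minimum and maximum.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def quantize(values, bits):\n",
    "    '''Map floats to integers in [0, 2**bits - 1] (uniform quantization).'''\n",
    "    v_min, v_max = values.min(), values.max()\n",
    "    scale = (v_max - v_min) / (2**bits - 1)\n",
    "    q = np.round((values - v_min) / scale).astype(np.int64)\n",
    "    return q, scale, v_min\n",
    "\n",
    "\n",
    "def dequantize(q, scale, v_min):\n",
    "    '''Map quantized integers back to approximate float values.'''\n",
    "    return q * scale + v_min\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "values = rng.uniform(0.0, 5.0, size=1000)  # synthetic data\n",
    "\n",
    "max_errors = {}\n",
    "for bits in (2, 4, 6, 8):\n",
    "    q, scale, v_min = quantize(values, bits)\n",
    "    restored = dequantize(q, scale, v_min)\n",
    "    max_errors[bits] = float(np.abs(restored - values).max())\n",
    "    print(f'bits={bits}: max rounding error {max_errors[bits]:.4f}')\n",
    "```\n",
    "\n",
    "Each extra bit roughly halves the worst-case rounding error, which is the precision that a lower `n_bits` gives up in exchange for speed."
   ]
  },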
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1e64ee0b-76af-4e6a-a1d7-31048115a6e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for n_bits=6 is 45585.76$\n",
      "Error for n_bits=7 is 43828.16$\n"
     ]
    }
   ],
   "source": [
    "# We fix all parameters as the best ones, except for n_bits.\n",
    "best = gs_results.best_params_\n",
    "cv_errors = [\n",
    "    {\"n_bits\": params[\"n_bits\"], \"score\": score}\n",
    "    for params, score in zip(\n",
    "        gs_results.cv_results_[\"params\"], gs_results.cv_results_[\"mean_test_score\"]\n",
    "    )\n",
    "    if (params[\"max_depth\"] == best[\"max_depth\"])\n",
    "    and (params[\"max_features\"] == best[\"max_features\"])  # noqa: W503\n",
    "    and (params[\"min_samples_leaf\"] == best[\"min_samples_leaf\"])  # noqa: W503\n",
    "    and (params[\"min_samples_split\"] == best[\"min_samples_split\"])  # noqa: W503\n",
    "]\n",
    "for el in cv_errors:\n",
    "    print(f\"Error for n_bits={el['n_bits']} is {print_as_dollars(-el['score'])}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0ae45c2a",
   "metadata": {},
   "source": [
    "## Training the Optimised Model\n",
    "We now train our model with the optimised hyper-parameters.\n",
    "First, to show how it works, let us fix `n_bits` to a lower value so that compute time remains reasonable.\n",
    "Once done, we will see how higher values of `n_bits` affect compute time.\n",
    "We also use the `fit_benchmark` method instead of `fit` to easily compare our model to its scikit-learn equivalent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "20431f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model with best hyper parameters\n",
    "model = ConcreteDecisionTreeRegressor(\n",
    "    max_depth=gs_results.best_params_[\"max_depth\"],\n",
    "    max_features=gs_results.best_params_[\"max_features\"],\n",
    "    min_samples_leaf=gs_results.best_params_[\"min_samples_leaf\"],\n",
    "    min_samples_split=gs_results.best_params_[\"min_samples_split\"],\n",
    "    n_bits=6,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "07e5cffc",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, sklearn_model = model.fit_benchmark(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "177e073d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sklearn  Mean Error: 45960.28$,-26.72% of baseline\n",
      "Concrete Mean Error: 46691.16$,-25.56% of baseline\n"
     ]
    }
   ],
   "source": [
    "# Compute the mean absolute error on the test set\n",
    "y_pred_concrete = model.predict(x_test)\n",
    "y_pred_sklearn = sklearn_model.predict(x_test)\n",
    "concrete_average_precision = mean_absolute_error(y_test, y_pred_concrete)\n",
    "sklearn_average_precision = mean_absolute_error(y_test, y_pred_sklearn)\n",
    "print(\n",
    "    f\"Sklearn  Mean Error: {print_as_dollars(sklearn_average_precision)},\"\n",
    "    f\"{print_compare_to_baseline(sklearn_average_precision, baseline_error)}\"\n",
    ")\n",
    "print(\n",
    "    f\"Concrete Mean Error: {print_as_dollars(concrete_average_precision)},\"\n",
    "    f\"{print_compare_to_baseline(concrete_average_precision, baseline_error)}\"\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d122d719",
   "metadata": {},
   "source": [
    "We see that the Concrete ML model has similar performance to the scikit-learn model.\n",
    "However, there can be a difference, as Concrete ML models perform a possibly lossy quantization of the data.\n",
    "In general, we should expect the accuracy to be slightly lower for Concrete ML models, but this is not always the case as models are themselves approximations."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8c34e664",
   "metadata": {},
   "source": [
    "## Predicting on Ciphertext\n",
    "Even though its predictions are similar to, or slightly less accurate than, those of scikit-learn, the real advantage of Concrete ML is privacy.\n",
    "We now show how to perform prediction on ciphertext with Concrete ML, so that the model does not need to decipher the data at all to compute its estimate."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "88c7288f",
   "metadata": {},
   "source": [
    "### Setup\n",
    "In order to prepare our model to work directly on ciphertext, we need to:\n",
    "\n",
    "  - Compile the model to a circuit: the circuit represents our model's prediction program in FHE operations.\n",
    "  - Generate an encryption key for our circuit.\n",
    "  \n",
    "When deploying the model as a service in a client-server architecture, the server compiles the model and serves it, while the client generates the key and keeps it secret. The client uses its secret key to encrypt the data before sending it to the server.\n",
    "\n",
    "Note that in order to compile the circuit, the compiler needs some examples of the data it will be using.\n",
    "Here we use the first 500 training observations as a sample. A larger sample would increase the compile time but make compilation safer."
   ]
  },
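  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a96bc3e5",
   "metadata": {},
   "source": [
    "One way to build intuition for why the sample size matters is to see the sample as a stand-in for all future inputs.\n",
    "The NumPy sketch below, on synthetic data, measures how many rows of a larger dataset fall inside the per-feature minimum and maximum of a smaller sample. It is only an intuition aid: the helper `coverage` is hypothetical and this is not a description of how the Concrete ML compiler processes its input set.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def coverage(sample, data):\n",
    "    '''Fraction of rows in data lying within the per-feature range of sample.'''\n",
    "    low, high = sample.min(axis=0), sample.max(axis=0)\n",
    "    inside = np.all((data >= low) & (data <= high), axis=1)\n",
    "    return float(inside.mean())\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "data = rng.normal(size=(5000, 8))  # synthetic stand-in for 8 features\n",
    "\n",
    "for size in (50, 500, 5000):\n",
    "    share = coverage(data[:size], data)\n",
    "    print(f'sample of {size}: {share:.3f} of rows in range')\n",
    "```\n",
    "\n",
    "The share of covered rows can only grow as the sample grows, which matches the idea that a larger sample is the safer choice."
   ]
  },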
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6ca2fd2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compiled with 500 samples in 1.4179 seconds\n"
     ]
    }
   ],
   "source": [
    "from concrete.compiler import check_gpu_available\n",
    "\n",
    "use_gpu_if_available = False\n",
    "device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\"\n",
    "\n",
    "x_train_subset = x_train[:500]\n",
    "\n",
    "begin = time.time()\n",
    "circuit = model.compile(x_train_subset, device=device)\n",
    "print(f\"Compiled with {len(x_train_subset)} samples in {(time.time() - begin):.4f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "132936f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating a key for an 8-bit circuit\n",
      "Key generation time: 0.84 seconds\n"
     ]
    }
   ],
   "source": [
    "print(f\"Generating a key for an {circuit.graph.maximum_integer_bit_width()}-bit circuit\")\n",
    "time_begin = time.time()\n",
    "circuit.client.keygen(force=False)\n",
    "print(f\"Key generation time: {time.time() - time_begin:.2f} seconds\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4c33be54",
   "metadata": {},
   "source": [
    "### Prediction\n",
    "Prediction can be done directly on ciphertext by passing the `fhe=\"execute\"` argument to the `predict` method.\n",
    "Note however that computations on ciphertext are slower than on plaintext.\n",
    "In order to get results quickly for this tutorial, we show predictions only on a few observations of the test sample."
   ]
  },
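  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "cb045438",
   "metadata": {},
   "outputs": [],
   "source": [
    "FHE_SAMPLES = 3\n",
    "x_test_small = x_test[:FHE_SAMPLES]\n",
    "# Plaintext predictions for the same observations, compared with the FHE predictions below\n",
    "y_pred = model.predict(x_test_small)"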
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "904f585b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Execution time: 1.54 seconds per sample\n"
     ]
    }
   ],
   "source": [
    "# Predict in FHE for a few examples\n",
    "time_begin = time.time()\n",
    "y_pred_fhe = model.predict(x_test_small, fhe=\"execute\")\n",
    "print(f\"Execution time: {(time.time() - time_begin) / FHE_SAMPLES:.2f} seconds per sample\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "43adb59d",
   "metadata": {},
   "source": [
    "We have previously obtained predictions from plaintext.\n",
    "We now show that the estimates from plaintext or ciphertext in our setting are the same.\n",
    "In general, they might slightly differ due to the addition of noise in ciphertext."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "07fe03ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cipher estimates:\n",
      "293651.38$, 277778.33$, 126984.38$\n",
      "Plain estimates:\n",
      "290800.00$, 214600.00$, 102800.00$\n",
      "Differences:\n",
      "2851.38$, 63178.33$, 24184.38$\n"
     ]
    }
   ],
   "source": [
    "# Compare FHE predictions with plaintext predictions\n",
    "print(\"Cipher estimates:\")\n",
    "print(f\"{', '.join(f'{print_as_dollars(x)}' for x in y_pred_fhe)}\")\n",
    "print(\"Plain estimates:\")\n",
    "print(f\"{', '.join(f'{print_as_dollars(x)}' for x in y_pred)}\")\n",
    "print(\"Differences:\")\n",
    "print(f\"{', '.join(f'{print_as_dollars(x)}' for x in (y_pred_fhe - y_pred))}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "89c294cd-f03d-4fda-90f4-629320802d7c",
   "metadata": {},
   "source": [
    "### Compute Times\n",
    "We finish this tutorial with an evaluation of accuracy and compute times as we increase the `n_bits` parameter.\n",
    "If you wish to re-run the cell below, please note that it takes a long time to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "15008ad0-05d0-4ec2-91ea-c4be2efbc05b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N_BITS = 6\n",
      "----------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sklearn  Mean Error: 45960.28$,-26.72% of baseline\n",
      "Concrete Mean Error: 46691.16$,-25.56% of baseline\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Circuit compiled with 500 samples in 1.4455 seconds\n",
      "Generating a key for an 8-bit circuit\n",
      "Key generation time: 0.82 seconds\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Execution time: 1.64 seconds per sample\n",
      "\n",
      "N_BITS = 7\n",
      "----------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sklearn  Mean Error: 45960.28$,-26.72% of baseline\n",
      "Concrete Mean Error: 43046.38$,-31.37% of baseline\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Circuit compiled with 500 samples in 1.3987 seconds\n",
      "Generating a key for an 9-bit circuit\n",
      "Key generation time: 0.58 seconds\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Execution time: 1.40 seconds per sample\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Concatenate all the steps in one function of n_bits\n",
    "\n",
    "\n",
    "def evaluate(n_bits):\n",
    "    model = ConcreteDecisionTreeRegressor(\n",
    "        max_depth=gs_results.best_params_[\"max_depth\"],\n",
    "        max_features=gs_results.best_params_[\"max_features\"],\n",
    "        min_samples_leaf=gs_results.best_params_[\"min_samples_leaf\"],\n",
    "        min_samples_split=gs_results.best_params_[\"min_samples_split\"],\n",
    "        n_bits=n_bits,\n",
    "        random_state=42,\n",
    "    )\n",
    "\n",
    "    model, sklearn_model = model.fit_benchmark(x_train, y_train)\n",
    "\n",
    "    y_pred_concrete = model.predict(x_test)\n",
    "    y_pred_sklearn = sklearn_model.predict(x_test)\n",
    "\n",
    "    concrete_average_precision = mean_absolute_error(y_test, y_pred_concrete)\n",
    "    sklearn_average_precision = mean_absolute_error(y_test, y_pred_sklearn)\n",
    "\n",
    "    print(\n",
    "        f\"Sklearn  Mean Error: {print_as_dollars(sklearn_average_precision)},\"\n",
    "        f\"{print_compare_to_baseline(sklearn_average_precision, baseline_error)}\"\n",
    "    )\n",
    "    print(\n",
    "        f\"Concrete Mean Error: {print_as_dollars(concrete_average_precision)},\"\n",
    "        f\"{print_compare_to_baseline(concrete_average_precision, baseline_error)}\"\n",
    "    )\n",
    "\n",
    "    x_train_subset = x_train[:500]\n",
    "    begin = time.time()\n",
    "    circuit = model.compile(x_train_subset)\n",
    "    print(\n",
    "        f\"Circuit compiled with {len(x_train_subset)} samples in {(time.time() - begin):.4f} \"\n",
    "        \"seconds\"\n",
    "    )\n",
    "    print(f\"Generating a key for an {circuit.graph.maximum_integer_bit_width()}-bit circuit\")\n",
    "\n",
    "    time_begin = time.time()\n",
    "    circuit.client.keygen(force=False)\n",
    "    print(f\"Key generation time: {time.time() - time_begin:.2f} seconds\")\n",
    "\n",
    "    time_begin = time.time()\n",
    "    model.predict(x_test_small, fhe=\"execute\")\n",
    "    print(f\"Execution time: {(time.time() - time_begin) / FHE_SAMPLES:.2f} seconds per sample\")\n",
    "\n",
    "\n",
    "for n_bits in [6, 7]:\n",
    "    header = f\"N_BITS = {n_bits}\"\n",
    "    print(header)\n",
    "    print(\"-\" * len(header))\n",
    "    evaluate(n_bits)\n",
    "    print()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0e38dfe7-8e43-4828-9ec5-ac0f2d88403a",
   "metadata": {},
   "source": [
    "Increasing `n_bits` can dramatically increase computation times, to the point of making them prohibitive, although for the two values tested here the execution times remain comparable.\n",
    "Yet better quantization precision does not necessarily lower the generalization error of the model, as estimates are already approximations in some sense."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "770e2842",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "Using a decision tree regressor in Concrete ML is very similar to using it through scikit-learn and can be done in a few lines of code.\n",
    "Data scientists can now use Concrete ML to perform privacy-preserving decision tree regression on encrypted data.\n",
    "\n",
    "### Going Further\n",
    "Some additional tools can smooth the development workflow:\n",
    "\n",
    "  - Selecting a relevant bit-size for [quantizing](../explanations/quantization.md) the model.\n",
    "  - Reducing the [compilation](../explanations/compilation.md) time by making use of [FHE simulation](../explanations/compilation.md#fhe-simulation).\n",
    "\n",
    "Once the model is carefully trained and quantized, it is ready to be deployed and used in production. Here are some useful links on the subject:\n",
    "    \n",
    "  - [Inference in the Cloud](../getting-started/cloud.md) summarizes the steps for cloud deployment.\n",
    "  - [Production Deployment](../guides/client_server.md) offers a high-level view of how to deploy a Concrete ML model in a client/server setting.\n",
    "  - [Client Server in Concrete ML](./ClientServer.ipynb) provides a more hands-on approach as another tutorial."
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
